package com.ruihu.admin.config;

import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.handlers.AbstractSqlParserHandler;
import com.ruihu.admin.common.domain.OrderItemDTO;
import com.ruihu.admin.common.domain.PageParamDTO;
import com.sun.javafx.collections.MappingChange;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.*;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.StatementType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * MyBatis plugin that appends the sort orders carried by a {@link PageParamDTO} to SELECT statements.
 *
 * <p>Intercepts {@code Executor.query(MappedStatement, Object, RowBounds, ResultHandler)}. When the
 * statement is a non-callable SELECT and the mapper parameter is (or, for a parameter map, contains)
 * a {@link PageParamDTO} but no {@link IPage}, the DTO's order items are appended to the SQL's
 * ORDER BY clause and the query is re-issued with the rewritten {@link BoundSql}. In every other
 * case the invocation proceeds unchanged.
 */
@Intercepts(
        {
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        }
)
public class OrderByInterceptor extends AbstractSqlParserHandler implements Interceptor {

    protected static final Log logger = LogFactory.getLog(OrderByInterceptor.class);

    /**
     * Accepted shape for an order-by column: a plain, optionally table-qualified or back-quoted
     * identifier. The column text is spliced into the SQL as-is via {@code new Column(...)} and
     * presumably originates from request input (confirm), so spaces, parentheses, quotes,
     * semicolons and comment markers are rejected to prevent SQL injection.
     */
    private static final Pattern SAFE_COLUMN_PATTERN = Pattern.compile("[A-Za-z0-9_.`]+");

    /** BoundSql's private {@code additionalParameters} map, read reflectively so it can be copied. */
    private static final Field ADDITIONAL_PARAMETERS_FIELD = lookupAdditionalParametersField();

    /**
     * Rewrites the SELECT's ORDER BY from a {@link PageParamDTO} parameter, if one applies.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @return the query result, from either the rewritten or the original statement
     * @throws Throwable whatever the underlying executor or reflection throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];

        // Only SELECT statements are rewritten; stored-procedure (CALLABLE) statements are left alone.
        if (SqlCommandType.SELECT != mappedStatement.getSqlCommandType()
                || StatementType.CALLABLE == mappedStatement.getStatementType()) {
            return invocation.proceed();
        }

        Object parameter = args[1];
        RowBounds rowBounds = (RowBounds) args[2];
        ResultHandler<?> resultHandler = (ResultHandler<?>) args[3];
        Executor executor = (Executor) invocation.getTarget();

        BoundSql boundSql;
        CacheKey cacheKey;
        if (args.length == 4) {
            // The 4-argument query(...) registered in @Intercepts: derive the SQL and cache key ourselves.
            boundSql = mappedStatement.getBoundSql(parameter);
            cacheKey = executor.createCacheKey(mappedStatement, parameter, rowBounds, boundSql);
        } else {
            // 6-argument query(...) overload; only reachable if that signature is registered too.
            cacheKey = (CacheKey) args[4];
            boundSql = (BoundSql) args[5];
        }

        PageParamDTO pageParamDTO = findPageParam(boundSql.getParameterObject());
        if (pageParamDTO == null) {
            return invocation.proceed();
        }

        String originalSql = boundSql.getSql();
        String orderBySql = concatOrderBy(originalSql, pageParamDTO);
        if (orderBySql.equals(originalSql)) {
            // Nothing appended (no usable order items, or the SQL could not be rewritten).
            return invocation.proceed();
        }

        // Fold the rewritten SQL into the cache key so differently ordered results are not mixed up (#3).
        cacheKey.update(orderBySql);
        BoundSql orderByBoundSql = new BoundSql(mappedStatement.getConfiguration(), orderBySql,
                boundSql.getParameterMappings(), parameter);
        copyAdditionalParameters(boundSql, orderByBoundSql);
        return executor.query(mappedStatement, parameter, rowBounds, resultHandler, cacheKey, orderByBoundSql);
    }

    /**
     * Appends the order items of {@code pageParamDTO} to the ORDER BY clause of {@code originalSql}.
     *
     * <p>Existing ORDER BY elements are kept and the new ones follow them. Blank, null or unsafe
     * columns are skipped. The SQL is returned unchanged when there is nothing to append, when it
     * does not parse as a SELECT, or when its body is neither a plain SELECT nor a set operation.
     * Any WITH clause is preserved because the whole statement is re-rendered, not just its body.
     *
     * @param originalSql  SQL produced by MyBatis for the statement
     * @param pageParamDTO holder of the requested orders; may be {@code null}
     * @return the rewritten SQL, or {@code originalSql} when no rewrite happened
     */
    public static String concatOrderBy(String originalSql, PageParamDTO pageParamDTO) {
        if (pageParamDTO == null || CollectionUtils.isEmpty(pageParamDTO.getOrders())) {
            return originalSql;
        }
        List<OrderByElement> newElements = toOrderByElements(pageParamDTO.getOrders());
        if (newElements.isEmpty()) {
            return originalSql;
        }
        try {
            Object parsed = CCJSqlParserUtil.parse(originalSql);
            if (!(parsed instanceof Select)) {
                return originalSql;
            }
            Select select = (Select) parsed;
            SelectBody selectBody = select.getSelectBody();
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                plainSelect.setOrderByElements(appendOrderBy(plainSelect.getOrderByElements(), newElements));
            } else if (selectBody instanceof SetOperationList) {
                SetOperationList setOperationList = (SetOperationList) selectBody;
                setOperationList.setOrderByElements(appendOrderBy(setOperationList.getOrderByElements(), newElements));
            } else {
                // Other body kinds (e.g. a bare WithItem) are not supported; leave the SQL untouched.
                return originalSql;
            }
            // Render the whole statement, not just the body, so a leading WITH clause is kept.
            return select.toString();
        } catch (JSQLParserException e) {
            logger.warn("failed to concat orderBy from PageParamDTO, exception=" + e.getMessage());
            return originalSql;
        }
    }

    /**
     * Locates the {@link PageParamDTO} to sort by in a mapper parameter.
     *
     * <p>Returns {@code null} when any {@link IPage} is present: sorting is then not handled here
     * (presumably the pagination plugin owns it for IPage queries; confirm). For a parameter map the
     * first PageParamDTO value encountered is used.
     */
    private static PageParamDTO findPageParam(Object paramObj) {
        if (paramObj instanceof IPage) {
            return null;
        }
        if (paramObj instanceof PageParamDTO) {
            return (PageParamDTO) paramObj;
        }
        if (!(paramObj instanceof Map)) {
            return null;
        }
        PageParamDTO found = null;
        for (Object value : ((Map<?, ?>) paramObj).values()) {
            if (value instanceof IPage) {
                return null;
            }
            if (found == null && value instanceof PageParamDTO) {
                found = (PageParamDTO) value;
            }
        }
        return found;
    }

    /** Copies the dynamic-SQL additional parameters (e.g. foreach/bind values) to the rebuilt BoundSql. */
    @SuppressWarnings("unchecked") // field is declared as Map<String, Object> in BoundSql
    private static void copyAdditionalParameters(BoundSql source, BoundSql target) throws IllegalAccessException {
        Map<String, Object> additionalParameters = (Map<String, Object>) ADDITIONAL_PARAMETERS_FIELD.get(source);
        if (additionalParameters != null) {
            additionalParameters.forEach(target::setAdditionalParameter);
        }
    }

    /** Converts order items to JSqlParser elements, skipping null items and blank or unsafe columns. */
    private static List<OrderByElement> toOrderByElements(List<OrderItemDTO> orderList) {
        List<OrderByElement> elements = new ArrayList<>(orderList.size());
        for (OrderItemDTO item : orderList) {
            if (item == null || StringUtils.isBlank(item.getColumn())) {
                continue;
            }
            String column = item.getColumn().trim();
            if (!SAFE_COLUMN_PATTERN.matcher(column).matches()) {
                logger.warn("ignoring order-by column that is not a plain identifier: " + column);
                continue;
            }
            OrderByElement element = new OrderByElement();
            element.setExpression(new Column(column));
            element.setAsc(item.isAsc());
            // Always emit an explicit ASC/DESC keyword.
            element.setAscDescPresent(true);
            elements.add(element);
        }
        return elements;
    }

    /** Returns a new list holding the existing ORDER BY elements (if any) followed by {@code extra}. */
    private static List<OrderByElement> appendOrderBy(List<OrderByElement> existing, List<OrderByElement> extra) {
        int existingSize = existing == null ? 0 : existing.size();
        List<OrderByElement> merged = new ArrayList<>(existingSize + extra.size());
        if (existing != null) {
            merged.addAll(existing);
        }
        merged.addAll(extra);
        return merged;
    }

    /** Resolves and opens BoundSql's private {@code additionalParameters} field once, at class load. */
    private static Field lookupAdditionalParametersField() {
        try {
            Field field = BoundSql.class.getDeclaredField("additionalParameters");
            field.setAccessible(true);
            return field;
        } catch (NoSuchFieldException e) {
            throw new IllegalStateException("BoundSql.additionalParameters not found; unsupported MyBatis version", e);
        }
    }
}
